import os
from random import random

import numpy as np
import torch
from datetime import datetime

from natsort import natsorted
from numba import njit
from tqdm import tqdm
from torch.multiprocessing import Pool, Process, set_start_method

from dn_nn.utils import get_data_from_slowfast_output_pkl, get_cuda_status_as_device, get_torch_float_tensor, \
    save_data_to_pickle, calculate_evaluation_metrics, create_directory, print_and_save_eval_metrics, init_data, \
    get_date_as_string

# NOTE(review): module-level `debug` flag initialised to None. It is not read by any of
# the functions visible in this chunk — presumably set or consumed elsewhere in the
# module/package. TODO confirm it is still used, otherwise remove it.
debug = None

@torch.no_grad()
def sample_one_instance(model, cnn_predictions, last_sample, var_sequence, device):
    """
    Draw a single sample from ``model``, starting from ``last_sample``.

    Runs with autograd disabled. Delegates to ``model.forward_sampling`` and moves the
    probability tensor it returns to the host as a NumPy array.

    :param model: object exposing ``forward_sampling(sample, cnn_predictions=...,
        var_sequence=..., device=...)`` that returns ``(new_sample, prob_tensor)``.
    :param cnn_predictions: conditioning inputs forwarded unchanged to the model.
    :param last_sample: the previous sample, used as the starting point.
    :param var_sequence: forwarded unchanged to ``forward_sampling``.
    :param device: forwarded unchanged to ``forward_sampling``.
    :return: ``(prob_1, new_sample)``. ``prob_1`` is the model's probability output as
        a NumPy array. ``new_sample`` is exactly what ``forward_sampling`` returned.
    """
    new_sample, prob_tensor = model.forward_sampling(
        last_sample,
        cnn_predictions=cnn_predictions,
        var_sequence=var_sequence,
        device=device,
    )
    prob_1 = prob_tensor.detach().cpu().numpy()
    return prob_1, new_sample


def sample(num_samples, num_true_labels, models, val_loader, device):
    """
    Estimate per-label probabilities for each batch by repeated sampling on CNN outputs.

    For every batch from ``val_loader``:

    1. Draw a random binary starting state, with each (example, label) entry 0 or 1
       with probability 0.5.
    2. Call ``sample_one_instance`` ``num_samples`` times. Each call receives a freshly
       shuffled ``var_sequence`` and is fed the sample produced by the previous call.
    3. Average the probability arrays returned by those calls to get the batch estimate.

    :param num_samples: number of sampling iterations per batch; must be positive.
    :param num_true_labels: number of labels, i.e. the width of the sampled state.
    :param models: object exposing ``forward_sampling``, passed to
        ``sample_one_instance``.
    :param val_loader: iterable of ``(cnn_predictions, labels)`` tensor pairs.
        ``cnn_predictions.shape[0]`` is taken as the batch size.
    :param device: torch device that the CNN predictions and sampled state are moved to.
    :return: ``(outputs, true_labels)``, two lists with one NumPy array per batch.
        ``outputs`` holds the averaged probabilities, shape
        ``(batch_size, num_true_labels)``. ``true_labels`` holds the matching labels.
    :raises ValueError: if ``num_samples`` is not positive, since the average would be
        0/0.
    """
    if num_samples <= 0:
        raise ValueError(f"num_samples must be positive, got {num_samples}")
    # NOTE(review): presumably the order in which forward_sampling visits labels —
    # confirm against the model. It is shuffled in place before every sampling step.
    var_sequence = np.arange(num_true_labels)
    outputs = []
    true_labels = []
    for cnn_predictions, labels in tqdm(val_loader):
        batch_size = cnn_predictions.shape[0]
        # Random 0/1 initial state for every (example, label) pair. NumPy's RNG is kept
        # so that runs seeded via np.random stay reproducible.
        initial_state = np.random.binomial(n=1, p=0.5, size=[batch_size, num_true_labels])
        this_random_sample = torch.as_tensor(initial_state, dtype=torch.float32, device=device)
        probs_sum = np.zeros((batch_size, num_true_labels))
        cnn_predictions = cnn_predictions.to(device)
        for _ in range(num_samples):
            np.random.shuffle(var_sequence)
            this_sample_prob_1, this_random_sample = sample_one_instance(
                models, cnn_predictions, this_random_sample, var_sequence, device
            )
            probs_sum += this_sample_prob_1
        outputs.append(probs_sum / num_samples)
        # Labels are only needed on the host, so skip the copy to `device` and back.
        true_labels.append(labels.detach().cpu().numpy())
    return outputs, true_labels